{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "",
   "id": "1f727b084080afd8"
  },
  {
   "cell_type": "code",
   "id": "6af2c31b74d877cb",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-25T07:50:31.730890Z",
     "start_time": "2024-04-25T07:50:31.261368Z"
    }
   },
   "source": [
    "# 导入必要的包\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import dataloader\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class Net(nn.Module):  # 定义了一个名为 Net 的类，它继承自 nn.Module\n",
    "    def __init__(self):  # 构造函数，3x32x32的卷积层，2x2的池化层，64个全连接层\n",
    "        super(Net, self).__init__()  # 调用父类的构造函数\n",
    "\n",
    "        # 定义网络结构\n",
    "        self.conv1 = nn.Conv2d(3,6,3)  # 第一个卷积层，输入通道为3，输出通道为6，卷积核大小为3x3\n",
    "        self.conv2 = nn.Conv2d(6,16,3)  # 第二个卷积层，输入通道为6，输出通道为16，卷积核大小为3x3\n",
    "        self.fc1 = nn.Linear(16*28*28,512)  # 第一个全连接层，输入特征数为16*28*28，输出特征数为512\n",
    "        self.fc2 = nn.Linear(512,64)  # 第二个全连接层，输入特征数为512，输出特征数为64\n",
    "        self.fc3 = nn.Linear(64,10)  # 第三个全连接层，输入特征数为64，输出特征数为10\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 定义前向传播过程\n",
    "        x=self.conv1(x)  # 卷积层 1\n",
    "        x=F.relu(x)  # 使用 ReLU 激活函数\n",
    "\n",
    "        x=self.conv2(x)  # 卷积层 2\n",
    "        x=F.relu(x)  # 使用 ReLU 激活函数\n",
    "\n",
    "        x=x.view(-1,16*28*28)  # 对输入进行展平操作\n",
    "        x=self.fc1(x)  # 全连接层 1\n",
    "        x=F.relu(x)  # 使用 ReLU 激活函数\n",
    "\n",
    "        x=self.fc2(x)  # 全连接层 2\n",
    "        x=F.relu(x)  # 使用 ReLU 激活函数\n",
    "\n",
    "        x=self.fc3(x)  # 全连接层 3\n",
    "        return x  # 返回输出\n",
    "\n",
    "\n",
    "net = Net()  # 创建一个名为net的神经网络实例\n",
    "print(net)  # 打印神经网络实例"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net(\n",
      "  (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (fc1): Linear(in_features=12544, out_features=512, bias=True)\n",
      "  (fc2): Linear(in_features=512, out_features=64, bias=True)\n",
      "  (fc3): Linear(in_features=64, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-25T07:50:35.405170Z",
     "start_time": "2024-04-25T07:50:34.252972Z"
    }
   },
   "cell_type": "code",
   "source": [
    "transform = transforms.Compose([transforms.ToTensor(),  # 将 PIL 图像或 numpy.ndarray 转换为张量\n",
    "                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # 标准化图像的张量\n",
    "\n",
    "# 构建训练数据集\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,  # 指定训练数据集，如果不存在则从互联网上下载\n",
    "                                        download=True, transform=transform)  # 使用前面定义的转换函数了，如果是私人数据集就把cifar10以及后面改了\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,  # 创建训练数据的数据加载器，每个批次包含4个样本\n",
    "                                          shuffle=True,  # 对数据进行洗牌\n",
    "                                          num_workers=2)  # 使用多个子进程来加载数据，加快数据加载速度\n",
    "\n",
    "# 构建测试数据集\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False,  # 指定测试数据集，如果不存在则从互联网上下载\n",
    "                                       download=True, transform=transform)  # 使用前面定义的转换函数\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4,  # 创建测试数据的数据加载器，每个批次包含4个样本\n",
    "                                         shuffle=False,  # 不对数据进行洗牌\n",
    "                                         num_workers=2)  # 使用多个子进程来加载数据，加快数据加载速度"
   ],
   "id": "823bd94fedcd95de",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-25T07:25:37.037743Z",
     "start_time": "2024-04-25T07:25:36.352940Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch.optim as optim\n",
    "criterion = nn.CrossEntropyLoss()  # 定义损失函数为交叉熵\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)  # 定义优化器为随机梯度下降，学习率为0.1\n"
   ],
   "id": "43b70704143a1523",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-25T07:58:32.566055Z",
     "start_time": "2024-04-25T07:58:32.565956Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# mini batch训练神经网络\n",
    "for epoch in range(2):\n",
    "    for i, data in enumerate(trainloader):\n",
    "        images, labels = data\n",
    "        outputs = net(images)  # 前向传播\n",
    "        loss = criterion(outputs, labels)  # 计算损失\n",
    "        optimizer.zero_grad()  # 梯度清零\n",
    "        loss.backward()  # 反向传播\n",
    "        optimizer.step()  # 更新参数\n",
    "        \n",
    "        if(i%1000==0):\n",
    "            print('Epoch: %d, Step: /%d, Loss: %.3f' %(epoch, i, loss.item()))"
   ],
   "id": "3056698572495e7d",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-25T08:31:05.450756Z",
     "start_time": "2024-04-25T08:30:44.380009Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 测试模型,需要知道准确率，一共有多少个正确的判断除以总共的样本数\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net(images)     # 把图像给到神经网络，得到输出\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum()\n",
    "\n",
    "print('准确率：',float(correct) / total)\n"
   ],
   "id": "b6883d479bae0b85",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "准确率： 0.0998\n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# 保存模型，包含所有参数\n",
    "torch.save(net.state_dict(), 'net.pth')"
   ],
   "id": "76aab9baa9502985"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# 把刚才保存的模型读取到新的网络中，并测试准确率\n",
    "net2.load_state_dict(torch.load('net.pth'))"
   ],
   "id": "ca0498ef642ad8cf"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-26T02:55:13.818619Z",
     "start_time": "2024-04-26T02:54:53.480751Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 加载读取模型，并使用\n",
    "net2 = Net()\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net2(images)     # 把新导入的图像给到神经网络，得到输出\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum()\n",
    "\n",
    "print('准确率：',float(correct) / total)"
   ],
   "id": "b8b6ce4c3ccf3720",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "准确率： 0.0707\n"
     ]
    }
   ],
   "execution_count": 20
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
